import random
import math
from PIL import Image

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from datasets import register
from utils import to_pixel_samples




def resize_fn(img, size):
    """Bicubically resize ``img`` and return the result as a tensor.

    Args:
        img: a ``PIL.Image.Image``, or any input ``transforms.ToPILImage``
            accepts (e.g. a tensor); non-PIL inputs are converted to PIL
            first. For float tensors, torchvision's ``ToPILImage`` quantizes
            values to 8 bits.
        size: target size forwarded unchanged to ``transforms.Resize``
            (an int or an ``(h, w)`` pair).

    Returns:
        The resized image as produced by ``transforms.ToTensor``.
    """
    pil_img = img if isinstance(img, Image.Image) else transforms.ToPILImage()(img)
    bicubic = transforms.Resize(size, InterpolationMode.BICUBIC)
    return transforms.ToTensor()(bicubic(pil_img))


@register('sr-implicit-paired')
class SRImplicitPaired(Dataset):
    """Wrap a dataset of pre-paired (LR, HR) images for implicit SR training.

    Each item of ``dataset`` is unpacked as ``(img_lr, img_hr)``; both are
    sliced as ``[:, rows, cols]``, so they are presumably CHW tensors — TODO
    confirm against the wrapped dataset. The upsampling factor is taken as
    ``img_hr.shape[-2] // img_lr.shape[-2]`` (an integer scale is assumed).

    Args:
        dataset: indexable source of ``(img_lr, img_hr)`` pairs.
        inp_size: side length of a random square LR crop; the HR crop covers
            the same region scaled by the factor. ``None`` keeps the full LR
            image and trims HR to exactly ``scale`` times its size.
        augment: apply the same random row-flip / column-flip / transpose to
            both crops.
        sample_q: if not ``None``, number of HR query coordinates sampled
            without replacement for ``coord`` and ``cell``.
    """

    def __init__(self, dataset, inp_size=None, augment=False, sample_q=None):
        self.dataset = dataset
        self.inp_size = inp_size
        self.augment = augment
        self.sample_q = sample_q

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the dict ``{'inp', 'coord', 'cell', 'gt', 'scale'}`` for ``idx``.

        Raises:
            ValueError: if ``inp_size`` exceeds the LR image height or width,
                or (from ``np.random.choice``) if ``sample_q`` exceeds the
                number of HR pixels.
        """
        img_lr, img_hr = self.dataset[idx]

        s = img_hr.shape[-2] // img_lr.shape[-2]  # assume int scale
        if self.inp_size is None:
            h_lr, w_lr = img_lr.shape[-2:]
            crop_lr = img_lr
            crop_hr = img_hr[:, :h_lr * s, :w_lr * s]
        else:
            size_lr = self.inp_size
            h_avail, w_avail = img_lr.shape[-2:]
            if size_lr > h_avail or size_lr > w_avail:
                # random.randint would raise the same exception type here,
                # but with an "empty range" message that hides the cause.
                raise ValueError(
                    f'inp_size={size_lr} is larger than the LR image '
                    f'({h_avail}x{w_avail}) at index {idx}')
            x0 = random.randint(0, h_avail - size_lr)
            y0 = random.randint(0, w_avail - size_lr)
            crop_lr = img_lr[:, x0:x0 + size_lr, y0:y0 + size_lr]
            # The HR crop covers the same region, scaled by s.
            size_hr = size_lr * s
            x1, y1 = x0 * s, y0 * s
            crop_hr = img_hr[:, x1:x1 + size_hr, y1:y1 + size_hr]

        if self.augment:
            hflip = random.random() < 0.5
            vflip = random.random() < 0.5
            dflip = random.random() < 0.5

            def augment(x):
                if hflip:
                    x = x.flip(-2)
                if vflip:
                    x = x.flip(-1)
                if dflip:
                    x = x.transpose(-2, -1)
                return x

            crop_lr = augment(crop_lr)
            crop_hr = augment(crop_hr)

        # Only the coordinates are used: 'gt' below is the HR crop itself, so
        # the per-pixel values returned by to_pixel_samples are discarded.
        hr_coord, _ = to_pixel_samples(crop_hr.contiguous())

        if self.sample_q is not None:
            sample_lst = np.random.choice(
                len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]

        # Per-query cell extent: 2 / H and 2 / W, i.e. one HR pixel assuming
        # coordinates span a length-2 range (presumably [-1, 1] as produced
        # by to_pixel_samples — confirm).
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / crop_hr.shape[-2]
        cell[:, 1] *= 2 / crop_hr.shape[-1]

        # NOTE(review): 'gt' is the full HR crop, not values gathered at
        # 'coord'; when sample_q is set, coord/cell hold sample_q points while
        # gt holds every pixel. Confirm the training loop expects this.
        return {
            'inp': crop_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': crop_hr,
            'scale': s,
        }


def resize_fn(img, size):
    """Resize an image with bicubic interpolation and return a tensor.

    ``img`` may already be a ``PIL.Image.Image``; otherwise it is converted
    with ``transforms.ToPILImage`` (float tensors are quantized to 8 bits by
    that conversion). ``size`` is passed straight to ``transforms.Resize``.
    The result comes from ``transforms.ToTensor``.
    """
    # NOTE(review): this duplicates the resize_fn defined earlier in this
    # module, and being later it is the binding that wins at import time.
    # The two behave the same; one copy can be deleted.
    if isinstance(img, Image.Image):
        pil_image = img
    else:
        to_pil = transforms.ToPILImage()
        pil_image = to_pil(img)

    resized = transforms.Resize(size, InterpolationMode.BICUBIC)(pil_image)

    to_tensor = transforms.ToTensor()
    return to_tensor(resized)


@register('sr-implicit-downsampled')
class SRImplicitDownsampled(Dataset):
    """Build (LR, HR) pairs on the fly by bicubically downsampling HR images.

    A continuous scale ``s`` is drawn uniformly from ``[scale_min, scale_max]``
    and reused for ``batch_per_gpu`` consecutive ``__getitem__`` calls, so
    items fetched back-to-back share one scale (and hence one crop size).

    Args:
        dataset: indexable source of HR images, sliced as ``[:, rows, cols]``
            (presumably CHW tensors — TODO confirm).
        inp_size: approximate side length of a random square HR crop; the LR
            side is ``floor(inp_size / s)`` and the HR side is that times
            ``s``, rounded. ``None`` uses the whole image.
        scale_min: lower bound of the scale draw.
        scale_max: upper bound of the scale draw; defaults to ``scale_min``.
        augment: apply the same random row-flip / column-flip / transpose to
            both images.
        sample_q: accepted but not used by this wrapper.
        batch_per_gpu: number of consecutive calls that share one scale.
    """

    def __init__(self, dataset, inp_size=None, scale_min=1, scale_max=None,
                 augment=False, sample_q=None, batch_per_gpu=16):
        self.dataset = dataset
        self.inp_size = inp_size
        self.scale_min = scale_min
        self.scale_max = scale_min if scale_max is None else scale_max
        self.augment = augment
        # Scale served until __getitem__ draws its first new one.
        self.last_s = random.uniform(self.scale_min, self.scale_max)
        self.batch_per_gpu = batch_per_gpu
        # NOTE(review): starting at -2 makes the first call (count -1) reuse
        # last_s and the second call (count 0) draw a fresh scale. This looks
        # like compensation for one probe call (e.g. dataset[0]) made before
        # training begins — confirm against the caller. Each DataLoader
        # worker holds its own copy of this counter.
        self.call_count = -2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'inp': LR, 'gt': HR, 'scale': s}`` for item ``idx``."""
        self.call_count += 1
        if self.call_count % self.batch_per_gpu == 0:
            self.last_s = random.uniform(self.scale_min, self.scale_max)
        s = self.last_s

        img = self.dataset[idx]

        # The 1e-9 keeps floor() from dropping a pixel when size / s is an
        # exact integer that float division lands just below.
        if self.inp_size is not None:
            lr_side = math.floor(self.inp_size / s + 1e-9)
            hr_side = round(lr_side * s)
            top = random.randint(0, img.shape[-2] - hr_side)
            left = random.randint(0, img.shape[-1] - hr_side)
            crop_hr = img[:, top:top + hr_side, left:left + hr_side]
            crop_lr = resize_fn(crop_hr, (lr_side, lr_side))
        else:
            lr_h = math.floor(img.shape[-2] / s + 1e-9)
            lr_w = math.floor(img.shape[-1] / s + 1e-9)
            # Trim HR so it is (after rounding) exactly s times the LR size.
            crop_hr = img[:, :round(lr_h * s), :round(lr_w * s)]
            crop_lr = resize_fn(crop_hr, (lr_h, lr_w))

        if self.augment:
            flip_rows, flip_cols, swap_hw = (
                random.random() < 0.5 for _ in range(3))

            def _transform(t):
                if flip_rows:
                    t = t.flip(-2)
                if flip_cols:
                    t = t.flip(-1)
                return t.transpose(-2, -1) if swap_hw else t

            crop_lr = _transform(crop_lr)
            crop_hr = _transform(crop_hr)

        return {
            'inp': crop_lr,
            'gt': crop_hr,
            'scale': s,
        }


@register('sr-implicit-uniform-varied')
class SRImplicitUniformVaried(Dataset):
    """Pair LR inputs with HR targets whose size grows linearly with the index.

    The HR image of item ``idx`` is resized with ``resize_fn`` to an integer
    size interpolated from ``size_min`` (first item) to ``size_max`` (last
    item). An int size makes torchvision's ``Resize`` match the shorter edge.

    Args:
        dataset: indexable source of ``(img_lr, img_hr)`` pairs.
        size_min: target size for the first item.
        size_max: target size for the last item; defaults to ``size_min``.
        augment: randomly flip both images along the last axis.
        gt_resize: if not ``None``, a second resize applied to the HR image.
        sample_q: if not ``None``, number of HR pixels sampled without
            replacement for ``coord``/``cell``/``gt``.
    """

    def __init__(self, dataset, size_min, size_max=None,
                 augment=False, gt_resize=None, sample_q=None):
        self.dataset = dataset
        self.size_min = size_min
        if size_max is None:
            size_max = size_min
        self.size_max = size_max
        self.augment = augment
        self.gt_resize = gt_resize
        self.sample_q = sample_q

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'inp', 'coord', 'cell', 'gt'}`` for item ``idx``."""
        img_lr, img_hr = self.dataset[idx]

        # Fraction of the way through the dataset. A single-item dataset has
        # no range to interpolate over (len - 1 == 0 previously raised
        # ZeroDivisionError), so it uses size_min.
        last = len(self.dataset) - 1
        p = idx / last if last > 0 else 0.0
        hr_size = round(self.size_min + (self.size_max - self.size_min) * p)
        img_hr = resize_fn(img_hr, hr_size)

        if self.augment:
            if random.random() < 0.5:
                img_lr = img_lr.flip(-1)
                img_hr = img_hr.flip(-1)

        if self.gt_resize is not None:
            img_hr = resize_fn(img_hr, self.gt_resize)

        hr_coord, hr_rgb = to_pixel_samples(img_hr)

        if self.sample_q is not None:
            sample_lst = np.random.choice(
                len(hr_coord), self.sample_q, replace=False)
            hr_coord = hr_coord[sample_lst]
            hr_rgb = hr_rgb[sample_lst]

        # Per-query cell extent: 2 / H and 2 / W, i.e. one HR pixel assuming
        # coordinates span a length-2 range (presumably [-1, 1] — confirm
        # against to_pixel_samples).
        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / img_hr.shape[-2]
        cell[:, 1] *= 2 / img_hr.shape[-1]

        return {
            'inp': img_lr,
            'coord': hr_coord,
            'cell': cell,
            'gt': hr_rgb
        }